import networkx as nx
import random
import simpy
import math
from .bgp import BgpProcess, BgpUpdate, BgpMessage, WITHDRAW, ANNOUNCE
from .fib import FibRecorder

class FatTreeProvider(object):
    """Generate a k-ary fat-tree topology as a ``networkx.MultiGraph``.

    Configuration keys read:
        k: switch radix; must be a positive power of two.
    """

    def __init__(self, config):
        """Read and validate the radix ``k`` from *config*.

        Raises:
            ValueError: if ``k`` is not a positive power of two.
        """
        config_k = 'k'

        K = config[config_k].get(int)
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        # For K > 0, `K & (K - 1)` is zero exactly when K is a power of two.
        if K <= 0 or K & (K - 1):
            raise ValueError(
                'fat-tree k must be a positive power of two, got %r' % (K,))

        self.K = K

    def generate(self):
        """Build and return the fat-tree graph.

        Nodes (each carries ``switch_type`` and a drawing position ``pos``):
          * ``K*K/4`` core switches ``c1 .. cN`` (``switch_type='core'``, y=3);
          * per pod, ``K/2`` aggregation switches ``a<pod>-<j>``
            (``switch_type='aggr'``, y=2) and ``K/2`` ToR switches
            ``t<pod>-<j>`` (``switch_type='tor'``, y=1).
        Switch names use 1-based indices; the ``pod`` attribute is 0-based.

        Links:
          * aggregation switch ``j`` of every pod connects to cores
            ``j, j + K/2, j + 2*K/2, ...``, so each core has one link per pod;
          * every aggregation switch connects to every ToR of its pod.

        Side effect: the layout (``n_cores``, ``cores``, ``n_pods``,
        ``aggr_pod``, ``tor_pod``, ``aggrs``, ``tors``) is stored on ``self``.
        """
        g = nx.MultiGraph()

        K = self.K

        n_cores = K * K // 4
        cores = ['c%d' % (i + 1) for i in range(n_cores)]

        for i in range(n_cores):
            # core switches
            g.add_node(cores[i], switch_type='core', pos=(i, 3))

        n_pods = K
        aggr_pod = K // 2
        tor_pod = K // 2
        aggrs = [['a%d-%d' % (i + 1, j + 1) for j in range(aggr_pod)]
                 for i in range(n_pods)]
        tors = [['t%d-%d' % (i + 1, j + 1) for j in range(tor_pod)]
                for i in range(n_pods)]

        for i in range(n_pods):
            for j in range(aggr_pod):
                # aggregation switches
                g.add_node(aggrs[i][j], switch_type='aggr', pod=i,
                           pos=(i * aggr_pod + j, 2))
            for j in range(tor_pod):
                # ToR switches
                g.add_node(tors[i][j], switch_type='tor', pod=i,
                           pos=(i * tor_pod + j, 1))

        for i in range(n_pods):
            for j in range(aggr_pod):
                for k in range(aggr_pod):
                    # core-aggr links: (j, k) pairs map to distinct cores.
                    g.add_edge(cores[k * aggr_pod + j], aggrs[i][j])

        for i in range(n_pods):
            for j in range(aggr_pod):
                for k in range(tor_pod):
                    # aggr-tor links: full bipartite mesh inside the pod
                    g.add_edge(aggrs[i][j], tors[i][k])

        self.n_cores = n_cores
        self.cores = cores
        self.n_pods = n_pods
        self.aggr_pod = aggr_pod
        self.tor_pod = tor_pod
        self.aggrs = aggrs
        self.tors = tors

        return g

class FbNetworkProvider(object):
    """Generate a Facebook-style data-center fabric as a ``networkx.MultiGraph``.

    Each of the ``npod`` pods has ``RSW_PER_POD`` rack switches (``rsw``,
    one host each) and ``FSW_PER_POD`` fabric switches (``fsw``). There are
    ``N_SPINE`` spine planes of ``SSW_PER_SPINE`` spine switches (``ssw``).

    Configuration keys read:
        npod: number of pods.
    """

    # Topology sizes; subclasses may override. Fabric switch j of a pod
    # uplinks to spine plane j, so FSW_PER_POD must not exceed N_SPINE.
    RSW_PER_POD = 48
    FSW_PER_POD = 4
    N_SPINE = 4
    SSW_PER_SPINE = 48

    def __init__(self, config):
        """Read the pod count ``npod`` from *config*."""
        config_npod = 'npod'
        self.npod = config[config_npod].get(int)

    def generate(self):
        """Build and return the fabric graph.

        Nodes: ``host-<pod>-<j>`` (``switch_type='host'``), ``rsw-<pod>-<j>``
        (``'tor'``), ``fsw-<pod>-<j>`` (``'aggr'``), all with ``pod``; and
        ``ssw-<plane>-<j>`` (``'core'``, no ``pod``). No ``pos`` is set.

        Links: each rsw to its host; every fsw of a pod to every rsw of that
        pod; fsw ``j`` of every pod to every ssw of spine plane ``j``.

        Raises:
            ValueError: if ``FSW_PER_POD`` exceeds ``N_SPINE``.
        """
        g = nx.MultiGraph()

        npod = self.npod
        rsw_pod = self.RSW_PER_POD
        fsw_pod = self.FSW_PER_POD
        nspine = self.N_SPINE
        ssw_spine = self.SSW_PER_SPINE
        if fsw_pod > nspine:
            raise ValueError('FSW_PER_POD (%d) must not exceed N_SPINE (%d)'
                             % (fsw_pod, nspine))

        rsw = [['rsw-%d-%d' % (i, j) for j in range(rsw_pod)]
               for i in range(npod)]
        hosts = [['host-%d-%d' % (i, j) for j in range(rsw_pod)]
                 for i in range(npod)]
        fsw = [['fsw-%d-%d' % (i, j) for j in range(fsw_pod)]
               for i in range(npod)]
        ssw = [['ssw-%d-%d' % (i, j) for j in range(ssw_spine)]
               for i in range(nspine)]

        # Nodes are added under the same names used for the edges below so
        # that no attribute-less node gets implicitly created by add_edge.
        for i in range(npod):
            for j in range(rsw_pod):
                g.add_node(rsw[i][j], switch_type='tor', pod=i)
                g.add_node(hosts[i][j], switch_type='host', pod=i)
            for j in range(fsw_pod):
                g.add_node(fsw[i][j], switch_type='aggr', pod=i)

        for i in range(nspine):
            for j in range(ssw_spine):
                g.add_node(ssw[i][j], switch_type='core')

        for i in range(npod):
            for j in range(rsw_pod):
                g.add_edge(rsw[i][j], hosts[i][j])
            for j in range(fsw_pod):
                for k in range(rsw_pod):
                    g.add_edge(fsw[i][j], rsw[i][k])
                for k in range(ssw_spine):
                    # fsw j uplinks to spine plane j
                    g.add_edge(fsw[i][j], ssw[j][k])
        return g

class NetworkProvider(object):
    """Build a topology from config and drive a BGP simulation over it.

    Configuration keys read:
        provider.type: ``'fattree'`` or ``'fb'``; the whole ``provider``
            section is passed to the selected topology provider.
        fib_recorder: passed to ``FibRecorder``.
    """

    def __init__(self, config):
        """Select the topology provider, create the FIB recorder, build the graph.

        Raises:
            ValueError: if ``provider.type`` names an unknown provider.
        """
        config_provider = 'provider'
        config_type = 'type'
        config_fib_recorder = 'fib_recorder'

        provider_config = config[config_provider]
        provider = provider_config[config_type].get(str)

        if provider == 'fattree':
            self.provider = FatTreeProvider(provider_config)
        elif provider == 'fb':
            self.provider = FbNetworkProvider(provider_config)
        else:
            # Fail loudly here rather than with an AttributeError on
            # self.provider further down.
            raise ValueError('unknown network provider type: %r' % (provider,))
        self.config = config
        self.fpm = FibRecorder(config[config_fib_recorder])

        self.g = self.provider.generate()
        self.env = None
        self.devices = {}

    def configure(self, env):
        """Create one ``BgpProcess`` per graph node and schedule it on *env*.

        Each process receives the node's attributes as keyword arguments and
        the FIB recorder's channel; its ``main_loop`` is started as a process.
        """
        self.env = env
        self.fpm.configure(env)
        for n, attributes in self.g.nodes(data=True):
            print('Create process for %s' % (str(n)))
            p = BgpProcess(env, n, **attributes)
            p.set_fpm(self.fpm.channel)
            self.devices[n] = p
            env.process(p.main_loop())

    def _crosses_pod(self, u, v, pod):
        """Return True if exactly one endpoint of link (u, v) is in *pod*.

        Nodes without a ``pod`` attribute (e.g. core switches) count as pod -1.
        """
        u_in = self.devices[u].attributes.get('pod', -1) == pod
        v_in = self.devices[v].attributes.get('pod', -1) == pod
        return u_in != v_in

    def _connect(self, u, v):
        """Register u and v as peers of each other, handing over their channels."""
        src = self.devices[u]
        dst = self.devices[v]
        src.add_peer(v, channel=dst.channel)
        dst.add_peer(u, channel=src.channel)

    @staticmethod
    def _tor_prefix(idx):
        """Return the /24 announced by the *idx*-th ToR.

        Yields ``192.168.<idx>.0/24`` for idx < 256 and keeps counting into
        the next octets afterwards, so the prefix stays a valid IPv4 network
        instead of e.g. ``192.168.300.0/24``.
        """
        import ipaddress

        base = int(ipaddress.IPv4Address('192.168.0.0'))
        return str(ipaddress.IPv4Network((base + (idx << 8), 24)))

    def start(self, new_pod=0, new_pod_delay=5000):
        """Bring up peerings and inject initial routes (a simpy process).

        1. Connect every link except those with exactly one endpoint in pod
           *new_pod*; those are connected by :meth:`add_new_pod` after
           *new_pod_delay* time units, simulating the pod joining later.
        2. For each device, after a random 6-10 unit delay, put
           ``(now, 'hello')`` on its channel and start its ``say_hello``.
        3. For each ToR, after a random 6-10 unit delay, announce one /24.
        """
        env = self.env
        for u, v in self.g.edges():
            if self._crosses_pod(u, v, new_pod):
                continue
            self._connect(u, v)
        for dev in self.devices.values():
            yield env.timeout(random.randint(6, 10))
            dev.channel.put((env.now, 'hello'))
            env.process(dev.say_hello())
        idx = 0
        for n, dev in self.devices.items():
            if dev.attributes.get('switch_type') != 'tor':
                continue
            yield env.timeout(random.randint(6, 10))
            updates = [BgpUpdate(ANNOUNCE, self._tor_prefix(idx), [])]
            msg = BgpMessage(env.now, n, updates)
            dev.channel.put(msg)
            idx += 1

        env.process(self.add_new_pod(pod=new_pod, timeout=new_pod_delay))

    def add_new_pod(self, pod=0, timeout=5000):
        """After *timeout*, connect the links between *pod* and the rest.

        Only links with exactly one endpoint in *pod* are connected: these
        are the ones :meth:`start` held back. Links inside the pod were
        already connected there and are not registered a second time.
        """
        env = self.env
        yield env.timeout(timeout)
        for u, v in self.g.edges():
            if self._crosses_pod(u, v, pod):
                self._connect(u, v)

    def draw(self, filename, g=None):
        """Render *g* (default: the generated topology) to *filename*.

        Uses the nodes' ``pos`` attributes when every node has one, otherwise
        falls back to a spring layout (FbNetworkProvider sets no ``pos``).
        """
        import matplotlib.pyplot as plt

        if g is None:
            g = self.g
        pos = nx.get_node_attributes(g, 'pos')
        if len(pos) != g.number_of_nodes():
            pos = nx.spring_layout(g)
        # Draw on a fresh figure and close it so repeated calls don't overlay.
        fig = plt.figure()
        try:
            nx.draw(g, pos)
            fig.savefig(filename)
        finally:
            plt.close(fig)

    def dump_fib(self):
        """Call ``dump_fib()`` on every device."""
        for dev in self.devices.values():
            dev.dump_fib()
